# coding:utf-8

import sys
from src.constant.file_and_path_constant import FileAndPathConstant

# Append the jormungandr project root (on the configured system drive) to
# sys.path. A raw string is used because the path's backslashes (\g, \h, \d,
# \j) are not valid escape sequences and would trigger DeprecationWarning /
# SyntaxWarning in a normal string literal; the resulting value is identical.
sys.path.append(FileAndPathConstant.System_Drive + r'\github-repository\hades\dev-project\jormungandr')

from src.manager.oracle_manager import OracleManager
from src.manager.lstm_manager import LstmManager
from src.constant.oracle import Oracle
from src.manager.log_manager import LogManager
from pandas.core.frame import DataFrame
import numpy as np

# Module-level logger, named after this module, obtained from the project's LogManager.
Logger = LogManager.get_logger(__name__)


class LstmHandler:
    """
    LSTM (long short-term memory network) handler.

    Opens an Oracle connection on construction, loads price rows for one stock
    and hands them to ``LstmManager.use_lstm``. ``use_lstm`` closes the cursor
    and connection, so an instance supports a single ``use_lstm`` call.
    """

    # Open/highest/lowest/close prices for stock 000004, oldest first.
    _PRICE_QUERY = (
        "select t.open_price, t.highest_price, t.lowest_price, t.close_price "
        "from stock_transaction_data_all t "
        "where t.code_='000004' "
        "order by t.date_ asc"
    )

    def __init__(self):
        """
        Constructor: create the OracleManager, open the connection and obtain a cursor.
        """
        self.oracle_manager = OracleManager(Oracle.Username, Oracle.Password, Oracle.Url)
        self.oracle_manager.connect()
        self.cursor = self.oracle_manager.get_cursor()

    def close_cursor_and_connect(self):
        """
        Close the cursor, then the database connection.

        The connection is closed even if closing the cursor raises, so a
        failure in ``cursor.close()`` cannot leak the connection.
        :return: None
        """
        try:
            self.cursor.close()
        finally:
            self.oracle_manager.connect_close()

    def use_lstm(self):
        """
        Query price data from the database, then run the LSTM algorithm on it.

        Fetches the rows selected by ``_PRICE_QUERY``, closes the cursor and
        connection, converts the rows to a NumPy array and passes it to
        ``LstmManager.use_lstm``. If the query returns no rows, the LSTM step
        is skipped.
        :return: None
        """
        Logger.info('从数据库中查询数据，开始准备运行LSTM算法')

        try:
            self.cursor.execute(self._PRICE_QUERY)
            stock_price_tuple_list = self.cursor.fetchall()
        finally:
            # Release DB resources even if the query fails, and before the
            # LSTM run so the connection is not held open during training.
            self.close_cursor_and_connect()

        Logger.info('从数据库中查询到的数据：' + str(stock_price_tuple_list))

        if not stock_price_tuple_list:
            # Nothing to train on; an empty array would only fail downstream.
            Logger.info('从数据库中未查询到数据，不运行LSTM算法')
            return

        # One row per trading day: (open, highest, lowest, close).
        stock_price_ndarray = np.array(stock_price_tuple_list)

        lstm_manager = LstmManager()
        lstm_manager.use_lstm(stock_price_ndarray)
